import random
from collections.abc import Sequence, Mapping

import torch

from Utils.Registry import Registry
from Utils.Misc import worker_init_fn
from torch.utils.data import DataLoader

# Module-level registry created with the name "Datasets"; `build_dataset`
# delegates to its `build` method to construct dataset objects from configs.
datasets = Registry("Datasets")


def collate_fn(batch):
    """Recursively collate a sequence of samples into one batched container.

    The handling is chosen from the type of the first element of ``batch``:

    * ``torch.Tensor``: all elements are combined with ``torch.stack`` (a new
      leading dimension is added, so the tensors must share a shape).
    * ``str``: the strings are returned as a plain ``list``.
    * ``Sequence``: a one-element tensor holding ``sample[0].shape[0]`` is
      appended to a *copy* of every sample, the samples are collated field by
      field, and the last field (those appended lengths) is replaced by its
      cumulative sum along dim 0, cast to ``torch.int32``.
    * ``Mapping``: values are collated per key, using the keys of the first
      sample. Every key whose name contains ``"offset"`` is replaced by its
      cumulative sum along dim 0. When the first sample has
      ``mode == 'train'``, only samples whose ``'offset'`` equals the first
      sample's ``'num_points'`` are kept beforehand.
    * anything else: ``None``.

    Args:
        batch: A non-empty ``Sequence`` of samples.

    Returns:
        The collated batch (tensor, list or dict depending on the sample
        type), or ``None`` when the element type is unsupported or when the
        train-mode filter leaves at most one sample.

    Raises:
        TypeError: If ``batch`` is not a ``Sequence``.
        IndexError: If ``batch`` is empty (``batch[0]`` is read).
    """
    if not isinstance(batch, Sequence):
        # Report the container type; most non-sequences have no ``dtype``
        # attribute, so formatting ``batch.dtype`` would raise AttributeError
        # instead of the intended TypeError.
        raise TypeError(f"{type(batch).__name__} is not supported.")

    first = batch[0]

    if isinstance(first, torch.Tensor):
        return torch.stack(list(batch))

    # str is itself a Sequence, so it must be checked before the generic
    # Sequence branch.
    if isinstance(first, str):
        return list(batch)

    if isinstance(first, Sequence):
        # Build extended copies rather than appending in place: the caller's
        # samples stay untouched (no growth on repeated collation) and
        # immutable samples such as tuples are accepted.
        extended = [
            [*sample, torch.tensor([sample[0].shape[0]])] for sample in batch
        ]
        collated = [collate_fn(fields) for fields in zip(*extended)]
        collated[-1] = torch.cumsum(collated[-1], dim=0).int()
        return collated

    if isinstance(first, Mapping):
        if first.get('mode', None) == 'train':
            # Keep only samples whose 'offset' matches the first sample's
            # 'num_points'; give up on the batch if at most one remains.
            num_points = first.get('num_points')
            batch = [item for item in batch if item['offset'] == num_points]
            if len(batch) <= 1:
                return None

        # Re-read batch[0]: the filter above may have dropped the original
        # first sample.
        collated = {key: collate_fn([d[key] for d in batch]) for key in batch[0]}
        for key in collated:
            if "offset" in key:
                collated[key] = torch.cumsum(collated[key], dim=0)
        return collated

    # Unsupported element types (e.g. plain numbers) collate to None.
    return None


def point_collate_fn(batch):
    """Collate a batch of dict samples by delegating to ``collate_fn``.

    Only mapping-style samples (input_dict) are supported here; list-style
    samples (input_list) are rejected by the assertion on the first element.
    """
    head = batch[0]
    assert isinstance(head, Mapping)
    return collate_fn(batch)


def build_dataset(cfgs, default_args=None):
    """
        Construct a dataset through the module-level `datasets` registry.
        Args:
            cfgs : dataset configuration, forwarded unchanged to
                `datasets.build`.
            default_args : optional extra arguments, forwarded as the
                `default_args` keyword of `datasets.build`.
        Returns:
            Whatever `datasets.build` returns for the given configuration.
    """
    dataset = datasets.build(cfgs, default_args=default_args)
    return dataset


def build_dataloader(cfgs):
    """Build a ``DataLoader`` for the dataset described by ``cfgs``.

    Args:
        cfgs: Configuration object. The attributes read here are
            ``dataset`` (passed to ``build_dataset``; its ``mode`` attribute
            is also read), ``batch_size`` and ``num_workers``.

    Returns:
        DataLoader: A loader that uses ``point_collate_fn`` for batching and
        ``worker_init_fn`` to initialise workers. Shuffling and dropping of
        the last incomplete batch are enabled for every mode except
        ``"test"``.
    """
    is_test = cfgs.dataset.mode == "test"
    num_workers = int(cfgs.num_workers)

    return DataLoader(
        build_dataset(cfgs.dataset),
        batch_size=cfgs.batch_size,
        num_workers=num_workers,
        shuffle=not is_test,
        drop_last=not is_test,
        worker_init_fn=worker_init_fn,
        pin_memory=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only keep workers alive when there are workers to keep.
        persistent_workers=num_workers > 0,
        collate_fn=point_collate_fn,
    )
